import os
import json
import time
from tqdm import tqdm
import fire

import concurrent.futures
import random
import time
from collections import deque
from functools import partial

from PIL import Image
import base64
from io import BytesIO

from google import genai

# Path to the JSON file holding the few-shot examples (referring-segmentation
# style). It is resolved relative to the directory containing this module.
EXAMPLE_FILE = os.path.join(os.path.dirname(__file__), "question_example_referring_segmentation.json")


def _load_examples(file_path):
    """Load few-shot examples from a JSON file, keeping only the fields used for prompting.

    The file is expected to contain a JSON list of objects. Each returned
    example is a dict with the keys ``question``, ``answer``, ``captions`` and
    ``sam_input``; a key missing from the source object defaults to ``""``.

    Args:
        file_path: Path to the JSON example file.

    Returns:
        A list of example dicts. An empty list is returned (after printing a
        warning) when the file is missing, unreadable, not valid JSON, or does
        not contain a list. Entries that are not JSON objects are skipped.
    """
    if not os.path.exists(file_path):
        print(f"Warning: few-shot example file not found: {file_path}")
        return []

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        print(f"Warning: failed to load few-shot examples from {file_path}: {e}")
        return []

    if not isinstance(data, list):
        print(f"Warning: few-shot example file {file_path} does not contain a JSON list")
        return []

    fields = ("question", "answer", "captions", "sam_input")
    # Skip malformed entries individually instead of discarding every example.
    return [
        {field: item.get(field, "") for field in fields}
        for item in data
        if isinstance(item, dict)
    ]


# Few-shot examples, loaded once at import time. generate_worker falls back to
# these when the caller does not pass its own examples.
EXAMPLES = _load_examples(EXAMPLE_FILE)


# Name of the Gemini model used for every request issued by ask_gemini.
gemini_model = "gemini-2.5-flash"

class TPMManager:
    """Sliding-window request-rate tracker that suggests a delay before each request.

    Timestamps of recorded requests are kept for ``window_size`` seconds. The
    current rate is the number of timestamps still inside the window, scaled
    to a per-minute figure, and is compared against ``target_tpm``.
    """

    def __init__(self, target_tpm=1000000, window_size=60):
        """Create a tracker.

        Args:
            target_tpm: Per-minute rate above which ``calculate_delay`` starts
                returning more than the minimum delay.
            window_size: Length of the sliding window in seconds.
        """
        self.target_tpm = target_tpm
        self.window_size = window_size
        # No maxlen: a capped deque would saturate the measured rate at its
        # capacity, so a target above that cap could never be reached. Growth
        # is bounded by pruning entries that fall out of the window.
        self.request_times = deque()
        self.last_tpm_report = 0
        self.report_interval = 10  # seconds between rate reports
        self.min_delay = 0.001  # seconds; floor for the suggested delay

    def _prune(self, now):
        """Drop timestamps older than the window, relative to ``now``."""
        # EAFP: an empty deque (including one drained by another thread
        # between the peek and the pop) simply ends the loop.
        try:
            while now - self.request_times[0] > self.window_size:
                self.request_times.popleft()
        except IndexError:
            pass

    def record_request(self):
        """Record that a request is being issued now."""
        now = time.time()
        self._prune(now)
        self.request_times.append(now)

    def get_current_tpm(self):
        """Return the number of requests in the window, scaled to one minute."""
        self._prune(time.time())
        return len(self.request_times) * (60 / self.window_size)

    def should_report_tpm(self):
        """Return True at most once per ``report_interval`` seconds.

        Returning True also restarts the interval.
        """
        now = time.time()
        if now - self.last_tpm_report >= self.report_interval:
            self.last_tpm_report = now
            return True
        return False

    def calculate_delay(self):
        """Return how many seconds to wait before issuing the next request.

        Below the target rate this is ``min_delay``; at or above it the delay
        grows in proportion to how far the current rate exceeds the target.
        The current rate is printed whenever a report is due.
        """
        current_tpm = self.get_current_tpm()
        if self.should_report_tpm():
            print(f"Current TPM: {current_tpm:.2f}, Target TPM: {self.target_tpm}")
        if current_tpm < self.target_tpm:
            return self.min_delay
        delay = (current_tpm / self.target_tpm) * (60 / self.target_tpm)
        return max(self.min_delay, delay)

# Module-wide rate tracker; ask_gemini consults it before every request.
tpm_manager = TPMManager(target_tpm=1000000)

def _exp_delay(attempt: int, base: float = 1.0, factor: float = 2.0, max_delay: float = 60.0) -> float:
    """Return the exponential backoff for a retry attempt, capped at ``max_delay``.

    The uncapped value is ``base * factor ** attempt``, so attempt 0 waits
    ``base`` seconds and each later attempt multiplies the wait by ``factor``.
    """
    backoff = base * pow(factor, attempt)
    return max_delay if max_delay < backoff else backoff

def ask_gemini(messages, api_key, max_retries=5, temperature=0.7, top_p=0.9, max_tokens=4096):
    """Send a prompt to Gemini and return the generated text.

    Args:
        messages: Either a plain prompt string, or a list of chat-style dicts
            with ``role`` (``system``/``user``/``assistant``; defaults to
            ``user``) and ``content``. A list is flattened into one text
            prompt with a label per turn; entries with any other role are
            dropped.
        api_key: Gemini API key.
        max_retries: Maximum number of attempts when the request raises.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling parameter forwarded to the model.
        max_tokens: Output-token limit forwarded as ``max_output_tokens``.

    Returns:
        The response text, or ``None`` when the response carries no text or
        every attempt raised.

    Raises:
        ValueError: If ``api_key`` is empty.
    """
    if not api_key:
        raise ValueError("Gemini API key is required.")

    # `from google import genai` is the google-genai SDK, which is client-based:
    # it has no module-level `configure()` or `GenerativeModel` class.
    client = genai.Client(api_key=api_key)
    # NOTE(review): on thinking-capable models the output-token cap may also
    # cover reasoning tokens - confirm that max_tokens leaves enough headroom.
    generation_config = {
        "temperature": temperature,
        "top_p": top_p,
        "max_output_tokens": max_tokens,
    }

    if isinstance(messages, str):
        content = messages
    else:
        # The whole conversation is sent as a single text prompt, so each turn
        # is labelled with its role instead of using native chat turns.
        role_labels = {
            "system": "System Instruction",
            "user": "User",
            "assistant": "Assistant",
        }
        content_parts = []
        for msg in messages:
            label = role_labels.get(msg.get("role", "user"))
            if label is not None:
                content_parts.append(f"{label}: {msg.get('content', '')}")
        content = "\n".join(content_parts)

    for attempt in range(max_retries):
        try:
            delay = tpm_manager.calculate_delay()
            if delay > 0:
                time.sleep(delay)
            tpm_manager.record_request()

            response = client.models.generate_content(
                model=gemini_model,
                contents=content,
                config=generation_config,
            )

            if hasattr(response, 'text') and response.text:
                return response.text
            # An empty response is not retried.
            print(f"Gemini API response format error: {response}")
            return None
        except Exception as e:
            # Boundary with a remote service: any failure is logged and retried.
            print(f"Gemini API request exception: {e}, retry attempt {attempt + 1}...")
            # No point backing off after the last attempt.
            if attempt + 1 < max_retries:
                time.sleep(_exp_delay(attempt, base=2.0))

    print(f"Failed after {max_retries} retry attempts")
    return None

def _build_few_shot_messages(system_prompt, shots, make_user_turn, answer_key, final_user_turn):
    """Build a chat message list: system prompt, one user/assistant pair per shot, final user turn."""
    messages = [{'role': 'system', 'content': system_prompt}]
    for shot in shots:
        messages.append({"role": "user", "content": make_user_turn(shot)})
        messages.append({"role": "assistant", "content": shot[answer_key]})
    messages.append({"role": "user", "content": final_user_turn})
    return messages


def generate_worker(image_path, api_key, caption_override=None, examples=None, num_shots=4):
    """Generate a segmentation command, a question and an answer for one caption.

    Three few-shot Gemini calls are chained:
      1. caption -> ``sam_input`` (segmentation command)
      2. caption + ``sam_input`` -> user-facing question
      3. caption + question -> answer

    Args:
        image_path: Path of the image the caption belongs to. It is only used
            in log messages; the image itself is not read.
        api_key: Gemini API key.
        caption_override: Caption text to generate from.
        examples: Few-shot examples, each a dict with ``captions``,
            ``sam_input``, ``question`` and ``answer``. Defaults to the
            module-level ``EXAMPLES``.
        num_shots: Maximum number of examples placed in each prompt. Fewer
            are used when fewer are available.

    Returns:
        A dict with ``question``, ``answer`` and ``sam_input``, or ``None``
        when no caption was given or any generation step failed.
    """
    if caption_override is None:
        # Without this guard the literal string "None" would be sent as the caption.
        print(f"No caption provided for image: {image_path}")
        return None
    captions_str = caption_override

    pool = EXAMPLES if examples is None else examples
    # Slicing tolerates fewer than num_shots examples instead of raising IndexError.
    shots = list(pool[:num_shots])

    # 1. captions + sam_input_examples -> sam_input
    messages = _build_few_shot_messages(
        "You are an AI visual remote sensing assistant. You receive examples of a caption and the desired input sentence for a segmentation model. Based on a new caption, generate a similar input sentence. The input sentence should be a concise command to segment objects. JUST provide the generated input sentence.",
        shots,
        lambda shot: shot['captions'],
        'sam_input',
        captions_str,
    )
    sam_input = ask_gemini(messages, api_key=api_key)
    if sam_input is None:
        print("sam_input generation failed.")
        return None

    # 2. captions + questions_examples + sam_input -> question
    messages = _build_few_shot_messages(
        "You are an AI visual remote sensing assistant. You receive examples of a caption and a segmentation command, along with the desired user-facing question. Based on a new caption and segmentation command, generate a similar user-facing question. JUST provide the generated question.",
        shots,
        lambda shot: f"{shot['captions']}\n{shot['sam_input']}",
        'question',
        f"{captions_str}\n{sam_input}",
    )
    question = ask_gemini(messages, api_key=api_key)
    if question is None:
        print("question generation failed.")
        return None

    # 3. Generate answer
    q_temp = "caption: {cap}\nquestion: {q}\n"
    answer_prompt = _build_few_shot_messages(
        "You are an AI visual remote sensing assistant. You receive examples of a caption and a question, along with a suitable answer. Based on a new caption and question, generate a similar suitable answer confirming you can perform the task.",
        shots,
        lambda shot: q_temp.format(cap=shot['captions'], q=shot['question']),
        'answer',
        q_temp.format(cap=captions_str, q=question),
    )
    answer = ask_gemini(answer_prompt, api_key=api_key)
    if answer is None:
        print("Answer generation failed.")
        return None

    return {
        "question": question,
        "answer": answer,
        "sam_input": sam_input,
    }

def _load_grounding_dataset(refs_file_path):
    """Load caption annotations from a JSON file and index them.

    The file is expected to hold a list of records. For each record:
      * the annotation id is the first of ``ann_id``, ``id``, ``image_id``
        that is present and not ``None`` (an id of ``0`` is valid);
      * the caption is ``raw`` or ``sent`` of the first entry of
        ``sentences`` when that list is non-empty, otherwise ``raw`` or
        ``sent`` of the record itself, otherwise ``""``;
      * records without an ``image_id`` are skipped because they cannot be
        linked to an image.

    When several records share an ``image_id``, the last one wins in the
    image-id mapping.

    Args:
        refs_file_path: Path to the annotation JSON file.

    Returns:
        A tuple ``(refs_map, imageid2annid, annid2filename)`` of dicts keyed
        by stringified ids: annotation id -> caption, image id -> annotation
        id, annotation id -> file name. ``(None, None, None)`` is returned
        when the file is missing or cannot be loaded/parsed.
    """
    if not os.path.exists(refs_file_path):
        print(f"Error: Annotation file not found at {refs_file_path}")
        return None, None, None

    refs_map = {}
    imageid2annid = {}
    annid2filename = {}

    try:
        with open(refs_file_path, "r", encoding="utf-8") as f:
            refs_data = json.load(f)

        for item in refs_data:
            image_id = item.get("image_id")
            if image_id is None:
                # One unlinkable record must not abort the whole load.
                continue

            # Explicit None checks: an `or` chain would reject a valid id of 0.
            ann_id = next(
                (item[key] for key in ("ann_id", "id") if item.get(key) is not None),
                image_id,
            )

            sentences = item.get("sentences")
            caption_source = sentences[0] if sentences else item
            caption = caption_source.get("raw") or caption_source.get("sent") or ""

            ann_str = str(ann_id)
            refs_map[ann_str] = caption.strip()
            imageid2annid[str(image_id)] = ann_str
            if "file_name" in item:
                annid2filename[ann_str] = item["file_name"]
    except Exception as e:
        # Callers rely on this function never raising; report and signal failure.
        print(f"Failed to load or parse refs file: {e}")
        return None, None, None

    return refs_map, imageid2annid, annid2filename

def get_caption(image_id, refs_map, imageid2annid):
    """Look up the caption for an image in the pre-loaded annotation maps.

    The image id is stringified, mapped to its annotation id, and the
    annotation id is then looked up in ``refs_map``. Returns ``None`` when
    either lookup has no entry.
    """
    annotation_key = imageid2annid.get(str(image_id))
    return refs_map.get(annotation_key) if annotation_key is not None else None

def process_single_image(image_info, api_key, refs_map, imageid2annid):
    """Build one output record for an image, or return ``None`` on failure.

    ``image_info`` must contain ``image_path`` and may contain ``image_id``
    (the file's base name is used when it is absent). The caption is looked
    up via ``get_caption`` and passed to ``generate_worker``; the generated
    question/answer/sam_input are combined with image metadata into a dict.
    ``None`` is returned when no caption is found or generation fails.
    """
    image_path = image_info['image_path']
    file_name = os.path.basename(image_path)
    image_id = image_info.get('image_id', file_name)

    print(f"Processing image: {image_path}")

    caption = get_caption(image_id, refs_map, imageid2annid)
    if caption is None:
        print(f"Caption not found for image_id: {image_id}")
        return None

    generated = generate_worker(image_path, api_key, caption_override=caption)
    if generated is None:
        print(f"Content generation failed for image: {image_path}")
        return None

    # Key order is kept stable because it determines the field order of the
    # serialized record.
    record = {"unique_id": f"referring_segmentation_{image_id}_{int(time.time())}"}
    record["image_id"] = image_id
    record["image_file_name"] = file_name
    record["image_path"] = image_path
    record["question"] = generated["question"]
    record["caption"] = caption
    record["answer"] = generated["answer"]
    record["sam_input"] = generated["sam_input"]
    record["data_source"] = "referring_segmentation"
    record["timestamp"] = time.time()
    return record

def generate_panoptic_data(
    image_dir,
    output_file,
    annotations_refs_path,
    api_key,
    max_images=None,
    num_workers=4,
    start_index=0,
):
    """Generate question/answer records for annotated images and write them as JSONL.

    Annotations are loaded from ``annotations_refs_path``; every image id whose
    annotation has a file name that exists under ``image_dir`` becomes a work
    item. Work items are shuffled, sliced by ``start_index`` / ``max_images``,
    and processed by ``process_single_image`` in a thread pool. Each successful
    result is written as one JSON line to ``output_file``, which is truncated
    at the start of the run.

    Args:
        image_dir: Directory containing the image files.
        output_file: Path of the JSONL file to write.
        annotations_refs_path: Path of the annotation JSON file.
        api_key: Gemini API key.
        max_images: Maximum number of images to process, or ``None`` for all.
        num_workers: Number of worker threads.
        start_index: Number of (shuffled) work items to skip.

    Returns:
        None. A summary is printed; the function returns early when the
        dataset cannot be loaded.
    """
    refs_map, imageid2annid, annid2filename = _load_grounding_dataset(annotations_refs_path)
    if not all([refs_map, imageid2annid, annid2filename]):
        print("Could not load dataset. Aborting.")
        return

    image_files = []
    for img_id, ann_id in imageid2annid.items():
        file_name = annid2filename.get(ann_id)
        if not file_name:
            continue
        image_path = os.path.join(image_dir, file_name)
        if os.path.exists(image_path):
            image_files.append({'image_path': image_path, 'image_id': img_id})

    # NOTE(review): the shuffle is unseeded, so start_index skips a different
    # set of images on every run - confirm this is the intended resume behavior.
    random.shuffle(image_files)

    if start_index > 0:
        image_files = image_files[start_index:]
    if max_images is not None:
        image_files = image_files[:max_images]

    print(f"Found {len(image_files)} valid image/annotation pairs (starting from index {start_index}).")

    # A bare file name such as "data.jsonl" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    success_count = 0
    failed_count = 0

    # Use functools.partial to pass fixed arguments to the worker function
    worker_func = partial(process_single_image, api_key=api_key, refs_map=refs_map, imageid2annid=imageid2annid)

    # Results are written only from this thread, so one handle is enough.
    # The executor is entered second so it shuts down before the file closes.
    with open(output_file, 'w', encoding='utf-8') as out_f, \
            concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        future_to_image = {
            executor.submit(worker_func, image_info): image_info for image_info in image_files
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_image), total=len(image_files)):
            image_info = future_to_image[future]
            try:
                result = future.result()
                if result:
                    out_f.write(json.dumps(result, ensure_ascii=False) + '\n')
                    # Flush per record so finished work survives an interrupted run.
                    out_f.flush()
                    success_count += 1
                else:
                    failed_count += 1
            except Exception as e:
                failed_count += 1
                print(f"An exception occurred for {image_info['image_path']}: {e}")

    print(f"\nProcessing complete! Success: {success_count}, Failed: {failed_count}")
    print(f"Results saved to: {output_file}")

def main(
    image_dir: str,
    annotations_refs_path: str,
    gemini_api_key: str,
    output_file: str = "data.jsonl",
    max_images: int = None,
    num_workers: int = 4,
    start_index: int = 0
):
    """
    Command-line entry point: print the run configuration, then generate the data.

    Args:
        image_dir (str): Path to the directory containing images.
        annotations_refs_path (str): Path to the JSON file with annotations (captions).
        gemini_api_key (str): Your Google Gemini API key.
        output_file (str): Path to the output JSONL file. Defaults to "data.jsonl".
        max_images (int, optional): Maximum number of images to process. Defaults to None.
        num_workers (int): Number of parallel workers to use. Defaults to 4.
        start_index (int): Index of the image to start processing from. Defaults to 0.
    """
    banner = (
        "Starting data generation...",
        f"Image directory: {image_dir}",
        f"Annotation file: {annotations_refs_path}",
        f"Output file: {output_file}",
        f"Number of workers: {num_workers}",
        f"Start index: {start_index}",
    )
    for line in banner:
        print(line)

    # The API key is deliberately left out of the banner above.
    run_options = {
        "image_dir": image_dir,
        "output_file": output_file,
        "annotations_refs_path": annotations_refs_path,
        "api_key": gemini_api_key,
        "max_images": max_images,
        "num_workers": num_workers,
        "start_index": start_index,
    }
    generate_panoptic_data(**run_options)

# When run as a script, hand main() to python-fire so that its parameters are
# supplied from the command line.
if __name__ == "__main__":
    fire.Fire(main)